In [1]:
import tensorflow as tf 
import numpy as np

# Import MNIST Data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)


Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

In [8]:
# Parameters 
learning_rate = 1e-4
epoch = 10
epoch_size = 100
batch_size = 50
training_iters = epoch * epoch_size


# Network Parameters 
n_input = 28
n_step = 28
n_hidden = 128
n_output = 10

Network Function


In [23]:
def network(x):
    x = tf.transpose(x, [1,0,2])
    x = tf.reshape(x, [-1, n_input])
    x = tf.split(x, n_step, 0)
    
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    
    outputs, states = tf.contrib.rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    
    W = tf.Variable(tf.truncated_normal([n_hidden, n_output], stddev=0.01), name = 'softmax_w')
    b = tf.Variable(tf.constant(value=0.1, shape=[n_output]), name = 'softmax_b')

    return tf.matmul(outputs, W) + b

Define the Model


In [24]:
# Create / Reset the graph
tf.reset_default_graph()

x = tf.placeholder(tf.float32, shape=[None, n_step, n_input])
y = tf.placeholder(tf.float32, shape=[None, n_output])

pred = network(x)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)


---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py in _call_cpp_shape_fn_impl(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn)
    670           graph_def_version, node_def_str, input_shapes, input_tensors,
--> 671           input_tensors_as_shapes, status)
    672   except errors.InvalidArgumentError as err:

/usr/lib/python3.5/contextlib.py in __exit__(self, type, value, traceback)
     65             try:
---> 66                 next(self.gen)
     67             except StopIteration:

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py in raise_exception_on_not_ok_status()
    465           compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 466           pywrap_tensorflow.TF_GetCode(status))
    467   finally:

InvalidArgumentError: Shape must be rank 2 but is rank 3 for 'MatMul' (op: 'MatMul') with input shapes: [28,?,128], [128,10].

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-24-e877c4965511> in <module>()
      5 y = tf.placeholder(tf.float32, shape=[None, n_output])
      6 
----> 7 pred = network(x)
      8 
      9 # Define loss and optimizer

<ipython-input-23-27800ab18c68> in network(x)
     11     weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_output]))}
     12     biases = {'out': tf.Variable(tf.random_normal([n_output]))}
---> 13     return tf.matmul(outputs, weights['out']) + biases['out']

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py in matmul(a, b, transpose_a, transpose_b, adjoint_a, adjoint_b, a_is_sparse, b_is_sparse, name)
   1763     else:
   1764       return gen_math_ops._mat_mul(
-> 1765           a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
   1766 
   1767 

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py in _mat_mul(a, b, transpose_a, transpose_b, name)
   1452   """
   1453   result = _op_def_lib.apply_op("MatMul", a=a, b=b, transpose_a=transpose_a,
-> 1454                                 transpose_b=transpose_b, name=name)
   1455   return result
   1456 

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py in apply_op(self, op_type_name, name, **keywords)
    761         op = g.create_op(op_type_name, inputs, output_types, name=scope,
    762                          input_types=input_types, attrs=attr_protos,
--> 763                          op_def=op_def)
    764         if output_structure:
    765           outputs = op.outputs

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in create_op(self, op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_shapes, compute_device)
   2327                     original_op=self._default_original_op, op_def=op_def)
   2328     if compute_shapes:
-> 2329       set_shapes_for_outputs(ret)
   2330     self._add_op(ret)
   2331     self._record_op_seen_by_control_dependencies(ret)

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in set_shapes_for_outputs(op)
   1715       shape_func = _call_cpp_shape_fn_and_require_op
   1716 
-> 1717   shapes = shape_func(op)
   1718   if shapes is None:
   1719     raise RuntimeError(

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py in call_with_requiring(op)
   1665 
   1666   def call_with_requiring(op):
-> 1667     return call_cpp_shape_fn(op, require_shape_fn=True)
   1668 
   1669   _call_cpp_shape_fn_and_require_op = call_with_requiring

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py in call_cpp_shape_fn(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn)
    608     res = _call_cpp_shape_fn_impl(op, input_tensors_needed,
    609                                   input_tensors_as_shapes_needed,
--> 610                                   debug_python_shape_fn, require_shape_fn)
    611     if not isinstance(res, dict):
    612       # Handles the case where _call_cpp_shape_fn_impl calls unknown_shape(op).

/home/ppyht2/.local/lib/python3.5/site-packages/tensorflow/python/framework/common_shapes.py in _call_cpp_shape_fn_impl(op, input_tensors_needed, input_tensors_as_shapes_needed, debug_python_shape_fn, require_shape_fn)
    674       missing_shape_fn = True
    675     else:
--> 676       raise ValueError(err.message)
    677 
    678   if missing_shape_fn:

ValueError: Shape must be rank 2 but is rank 3 for 'MatMul' (op: 'MatMul') with input shapes: [28,?,128], [128,10].

In [ ]: